Skip to main content
Version: 1.0.1

Deep Learning - Deep Vision Classifier

Environment Setup on Databricks

Reinstall Horovod so that it is built against the new version of PyTorch.

# install cloudpickle 2.0.0 to add synapse module for usage of horovod
# (--force-reinstall --no-deps: replace only the cloudpickle package itself)
%pip install cloudpickle==2.0.0 --force-reinstall --no-deps
import synapse
import cloudpickle

# Ask cloudpickle to serialize the synapse module by value instead of by reference.
# NOTE(review): presumably so Horovod workers can unpickle objects from synapse
# without importing it themselves — confirm.
cloudpickle.register_pickle_by_value(synapse)
# Shell out to horovodrun to check the Horovod build.
! horovodrun --check-build
from pyspark.sql.functions import udf, col, regexp_replace
from pyspark.sql.types import IntegerType
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

Read Dataset

def assign_label(path: str) -> int:
    """Derive an integer class label from an image file path.

    The last ``/``-separated component of *path* is expected to look like
    ``<prefix>_<number>.<ext>`` (for example ``image_0001.jpg``). The number is
    parsed as an integer and integer-divided by 81 to obtain the label.

    Args:
        path: File path or URI of the image; only the final path component
            is inspected.

    Returns:
        ``number // 81`` for the number embedded in the file name.

    Raises:
        IndexError: If the file name contains no ``_`` separator.
        ValueError: If the part after the first ``_`` is not an integer.
    """
    file_name = path.split("/")[-1]
    stem = file_name.split(".")[0]
    num = int(stem.split("_")[1])
    # NOTE(review): every run of 81 consecutive numbers shares one label;
    # confirm this matches how the dataset numbers the images of each class.
    return num // 81


# Spark UDF wrapper so the label can be computed per row from the file path.
assign_label_udf = udf(assign_label, IntegerType())


def _load_labeled_images(directory):
    """Read the ``*.jpg`` files under *directory* into an (image, label) DataFrame.

    ``image`` is the file path with ``dbfs:`` rewritten to ``/dbfs``;
    ``label`` is computed from the original path by ``assign_label_udf``.
    """
    jpg_files = (
        spark.read.format("binaryFile")
        .option("pathGlobFilter", "*.jpg")
        .load(directory)
    )
    with_image = jpg_files.withColumn(
        "image", regexp_replace("path", "dbfs:", "/dbfs")
    )
    with_label = with_image.withColumn("label", assign_label_udf(col("path")))
    return with_label.select("image", "label")


# These files are already uploaded for build test machine
train_df = _load_labeled_images("/tmp/17flowers/train")

# Preview the first rows of the training set in the notebook.
display(train_df.limit(100))
test_df = _load_labeled_images("/tmp/17flowers/test")

Training

from horovod.spark.common.store import DBFSLocalStore
from pytorch_lightning.callbacks import ModelCheckpoint
from synapse.ml.dl import *
import uuid

# Number of training epochs.
epochs = 10

# Checkpoint files are named from the epoch number and the training loss.
checkpoint_callback = ModelCheckpoint(filename="{epoch}-{train_loss:.2f}")
callbacks = [checkpoint_callback]

# Each run gets its own output directory, suffixed with 8 characters of a random UUID.
run_output_dir = f"/dbfs/FileStore/test/resnet50/{str(uuid.uuid4())[:8]}"
store = DBFSLocalStore(run_output_dir)

# Configure the classifier: ResNet-50 backbone, 17 target classes.
# NOTE(review): validation=0.1 presumably holds out 10% of train_df for
# validation — confirm against the DeepVisionClassifier documentation.
deep_vision_classifier = DeepVisionClassifier(
    backbone="resnet50",
    store=store,
    callbacks=callbacks,
    num_classes=17,
    batch_size=16,
    epochs=epochs,
    validation=0.1,
)

# Fit on the training DataFrame; the result is the model used for prediction.
deep_vision_model = deep_vision_classifier.fit(train_df)

Prediction

# Score the test images with the fitted model and report accuracy.
pred_df = deep_vision_model.transform(test_df)
evaluator = MulticlassClassificationEvaluator(
    predictionCol="prediction", labelCol="label", metricName="accuracy"
)
print("Test accuracy:", evaluator.evaluate(pred_df))

# Cleanup the output dir for test.
# run_output_dir is a local FUSE path ("/dbfs/..."), but dbutils.fs resolves
# scheme-less paths against the DBFS root. Passing it unchanged would target
# "dbfs:/dbfs/..." and leave the real output directory behind, so translate
# the FUSE prefix into the dbfs: scheme first.
if run_output_dir.startswith("/dbfs/"):
    cleanup_path = "dbfs:" + run_output_dir[len("/dbfs"):]
else:
    cleanup_path = run_output_dir
dbutils.fs.rm(cleanup_path, True)